{
 "cells": [
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "from fastai.data.all import *\n",
    "from fastai.text.core import *\n",
    "from fastai.text.models.awdlstm import dropout_mask"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "from nbdev.showdoc import *"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# all_cpp"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#default_exp text.models.qrnn\n",
    "#default_cls_lvl 3"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "# QRNN\n",
    "\n",
    "> Quasi-recurrent neural networks introduced in [Bradbury et al.](https://arxiv.org/abs/1611.01576)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## ForgetMult"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export \n",
    "from torch.utils import cpp_extension\n",
    "from torch.autograd import Function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "__file__ = Path.cwd().parent/'fastai'/'text'/'models'/'qrnn.py'"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def load_cpp(name, files, path):\n",
    "    os.makedirs(Config().model/'qrnn', exist_ok=True)\n",
    "    return cpp_extension.load(name=name, sources=[path/f for f in files], build_directory=Config().model/'qrnn')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class _LazyBuiltModule():\n",
    "    \"A module with a CPP extension that builds itself at first use\"\n",
    "    def __init__(self, name, files): self.name,self.files,self.mod = name,files,None\n",
    "\n",
    "    def _build(self):\n",
    "        self.mod = load_cpp(name=self.name, files=self.files, path=Path(__file__).parent)\n",
    "\n",
    "    def forward(self, *args, **kwargs):\n",
    "        if self.mod is None: self._build()\n",
    "        return self.mod.forward(*args, **kwargs)\n",
    "\n",
    "    def backward(self, *args, **kwargs):\n",
    "        if self.mod is None: self._build()\n",
    "        return self.mod.backward(*args, **kwargs)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "forget_mult_cuda = _LazyBuiltModule('forget_mult_cuda', ['forget_mult_cuda.cpp', 'forget_mult_cuda_kernel.cu'])\n",
    "bwd_forget_mult_cuda = _LazyBuiltModule('bwd_forget_mult_cuda', ['bwd_forget_mult_cuda.cpp', 'bwd_forget_mult_cuda_kernel.cu'])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def dispatch_cuda(cuda_class, cpu_func, x):\n",
    "    \"Depending on `x.device` uses `cpu_func` or `cuda_class.apply`\"\n",
    "    return cuda_class.apply if x.device.type == 'cuda' else cpu_func"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "The ForgetMult gate is the quasi-recurrent part of the network, computing the following from `x` and `f`.\n",
    "``` python\n",
    "h[i+1] = x[i] * f[i] + h[i] + (1-f[i])\n",
    "```\n",
    "The initial value for `h[0]` is either a tensor of zeros or the previous hidden state."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "def forget_mult_CPU(x, f, first_h=None, batch_first=True, backward=False):\n",
    "    \"ForgetMult gate applied to `x` and `f` on the CPU.\"\n",
    "    result = []\n",
    "    dim = (1 if batch_first else 0)\n",
    "    forgets = f.split(1, dim=dim)\n",
    "    inputs =  x.split(1, dim=dim)\n",
    "    prev_h = None if first_h is None else first_h.unsqueeze(dim)\n",
    "    idx_range = range(len(inputs)-1,-1,-1) if backward else range(len(inputs))\n",
    "    for i in idx_range:\n",
    "        prev_h = inputs[i] * forgets[i] if prev_h is None else inputs[i] * forgets[i] + (1-forgets[i]) * prev_h\n",
    "        if backward: result.insert(0, prev_h)\n",
    "        else:        result.append(prev_h)\n",
    "    return torch.cat(result, dim=dim)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "`first_h` is the tensor used for the value of `h[0]` (defaults to a tensor of zeros). If `batch_first=True`, `x` and `f` are expected to be of shape `batch_size x seq_length x n_hid`, otherwise they are expected to be of shape `seq_length x batch_size x n_hid`. If `backwards=True`, the elements in `x` and `f` on the sequence dimension are read in reverse."
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def manual_forget_mult(x, f, h=None, batch_first=True, backward=False):\n",
    "    if batch_first: x,f = x.transpose(0,1),f.transpose(0,1)\n",
    "    out = torch.zeros_like(x)\n",
    "    prev = h if h is not None else torch.zeros_like(out[0])\n",
    "    idx_range = range(x.shape[0]-1,-1,-1) if backward else range(x.shape[0])\n",
    "    for i in idx_range:\n",
    "        out[i] = f[i] * x[i] + (1-f[i]) * prev\n",
    "        prev = out[i]\n",
    "    if batch_first: out = out.transpose(0,1)\n",
    "    return out\n",
    "\n",
    "x,f = torch.randn(5,3,20).chunk(2, dim=2)\n",
    "for (bf, bw) in [(True,True), (False,True), (True,False), (False,False)]:\n",
    "    th_out = manual_forget_mult(x, f, batch_first=bf, backward=bw)\n",
    "    out = forget_mult_CPU(x, f, batch_first=bf, backward=bw)\n",
    "    test_close(th_out,out)\n",
    "    h = torch.randn((5 if bf else 3), 10)\n",
    "    th_out = manual_forget_mult(x, f, h=h, batch_first=bf, backward=bw)\n",
    "    out = forget_mult_CPU(x, f, first_h=h, batch_first=bf, backward=bw)\n",
    "    test_close(th_out,out)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "data": {
      "text/plain": [
       "torch.Size([3, 4, 5, 0, 1, 0])"
      ]
     },
     "execution_count": null,
     "metadata": {},
     "output_type": "execute_result"
    }
   ],
   "source": [
    "x = torch.randn(3,4,5)\n",
    "x.size() + torch.Size([0,1,0])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class ForgetMultGPU(Function):\n",
    "    \"Wrapper around the CUDA kernels for the ForgetMult gate.\"\n",
    "    @staticmethod\n",
    "    def forward(ctx, x, f, first_h=None, batch_first=True, backward=False):\n",
    "        ind = -1 if backward else 0\n",
    "        (i,j) = (0,1) if batch_first else (1,0)\n",
    "        output = x.new_zeros(x.shape[0]+i, x.shape[1]+j, x.shape[2])\n",
    "        if first_h is not None:\n",
    "            if batch_first: output[:, ind] = first_h\n",
    "            else:           output[ind]    = first_h\n",
    "        else: output.zero_()\n",
    "        ctx.forget_mult = bwd_forget_mult_cuda if backward else forget_mult_cuda\n",
    "        output = ctx.forget_mult.forward(x, f, output, batch_first)\n",
    "        ctx.save_for_backward(x, f, first_h, output)\n",
    "        ctx.batch_first = batch_first\n",
    "        if backward: return output[:,:-1] if batch_first else output[:-1]\n",
    "        else:        return output[:,1:]  if batch_first else output[1:]\n",
    "\n",
    "    @staticmethod\n",
    "    def backward(ctx, grad_output):\n",
    "        x, f, first_h, output = ctx.saved_tensors\n",
    "        grad_x, grad_f, grad_h = ctx.forget_mult.backward(x, f, output, grad_output, ctx.batch_first)\n",
    "        return (grad_x, grad_f, (None if first_h is None else grad_h), None, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#hide\n",
    "#cuda\n",
    "#cpp\n",
    "def detach_and_clone(t):\n",
    "    return t.detach().clone().requires_grad_(True)\n",
    "\n",
    "x,f = torch.randn(5,3,20).cuda().chunk(2, dim=2)\n",
    "x,f = x.contiguous().requires_grad_(True),f.contiguous().requires_grad_(True)\n",
    "th_x,th_f = detach_and_clone(x),detach_and_clone(f)\n",
    "for (bf, bw) in [(True,True), (False,True), (True,False), (False,False)]:\n",
    "    th_out = forget_mult_CPU(th_x, th_f, first_h=None, batch_first=bf, backward=bw)\n",
    "    th_loss = th_out.pow(2).mean()\n",
    "    th_loss.backward()\n",
    "    out = ForgetMultGPU.apply(x, f, None, bf, bw)\n",
    "    loss = out.pow(2).mean()\n",
    "    loss.backward()\n",
    "    test_close(th_out,out, eps=1e-4)\n",
    "    test_close(th_x.grad,x.grad, eps=1e-4)\n",
    "    test_close(th_f.grad,f.grad, eps=1e-4)\n",
    "    for p in [x,f, th_x, th_f]:\n",
    "        p = p.detach()\n",
    "        p.grad = None\n",
    "    h = torch.randn((5 if bf else 3), 10).cuda().requires_grad_(True)\n",
    "    th_h = detach_and_clone(h)\n",
    "    th_out = forget_mult_CPU(th_x, th_f, first_h=th_h, batch_first=bf, backward=bw)\n",
    "    th_loss = th_out.pow(2).mean()\n",
    "    th_loss.backward()\n",
    "    out = ForgetMultGPU.apply(x.contiguous(), f.contiguous(), h, bf, bw)\n",
    "    loss = out.pow(2).mean()\n",
    "    loss.backward()\n",
    "    test_close(th_out,out, eps=1e-4)\n",
    "    test_close(th_x.grad,x.grad, eps=1e-4)\n",
    "    test_close(th_f.grad,f.grad, eps=1e-4)\n",
    "    test_close(th_h.grad,h.grad, eps=1e-4)\n",
    "    for p in [x,f, th_x, th_f]:\n",
    "        p = p.detach()\n",
    "        p.grad = None"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## QRNN"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class QRNNLayer(Module):\n",
    "    \"Apply a single layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.\"\n",
    "    def __init__(self, input_size, hidden_size=None, save_prev_x=False, zoneout=0, window=1,\n",
    "                 output_gate=True, batch_first=True, backward=False):\n",
    "        assert window in [1, 2], \"This QRNN implementation currently only handles convolutional window of size 1 or size 2\"\n",
    "        self.save_prev_x,self.zoneout,self.window = save_prev_x,zoneout,window\n",
    "        self.output_gate,self.batch_first,self.backward = output_gate,batch_first,backward\n",
    "        hidden_size = ifnone(hidden_size, input_size)\n",
    "        #One large matmul with concat is faster than N small matmuls and no concat\n",
    "        mult = (3 if output_gate else 2)\n",
    "        self.linear = nn.Linear(window * input_size, mult * hidden_size)\n",
    "        self.prevX = None\n",
    "\n",
    "    def reset(self): self.prevX = None\n",
    "\n",
    "    def forward(self, inp, hid=None):\n",
    "        y = self.linear(self._get_source(inp))\n",
    "        if self.output_gate: z_gate,f_gate,o_gate = y.chunk(3, dim=2)\n",
    "        else:                z_gate,f_gate        = y.chunk(2, dim=2)\n",
    "        z_gate.tanh_()\n",
    "        f_gate.sigmoid_()\n",
    "        if self.zoneout and self.training:\n",
    "            f_gate = f_gate * dropout_mask(f_gate, f_gate.size(), self.zoneout).requires_grad_(False)\n",
    "        z_gate,f_gate = z_gate.contiguous(),f_gate.contiguous()\n",
    "        forget_mult = dispatch_cuda(ForgetMultGPU, partial(forget_mult_CPU), inp)\n",
    "        c_gate = forget_mult(z_gate, f_gate, hid, self.batch_first, self.backward)\n",
    "        output = torch.sigmoid(o_gate) * c_gate if self.output_gate else c_gate\n",
    "        if self.window > 1 and self.save_prev_x:\n",
    "            if self.backward: self.prevX = (inp[:, :1]  if self.batch_first else inp[:1]) .detach()\n",
    "            else:             self.prevX = (inp[:, -1:] if self.batch_first else inp[-1:]).detach()\n",
    "        idx = 0 if self.backward else -1\n",
    "        return output, (c_gate[:, idx] if self.batch_first else c_gate[idx])\n",
    "\n",
    "    def _get_source(self, inp):\n",
    "        if self.window == 1: return inp\n",
    "        dim = (1 if self.batch_first else 0)\n",
    "        if self.batch_first:\n",
    "            prev = torch.zeros_like(inp[:,:1]) if self.prevX is None else self.prevX\n",
    "            if prev.shape[0] < inp.shape[0]: prev = torch.cat([prev, torch.zeros_like(inp[prev.shape[0]:, :1])], dim=0)\n",
    "            if prev.shape[0] > inp.shape[0]: prev= prev[:inp.shape[0]]\n",
    "        else:\n",
    "            prev = torch.zeros_like(inp[:1]) if self.prevX is None else self.prevX\n",
    "            if prev.shape[1] < inp.shape[1]: prev = torch.cat([prev, torch.zeros_like(inp[:1, prev.shape[0]:])], dim=1)\n",
    "            if prev.shape[1] > inp.shape[1]: prev= prev[:,:inp.shape[1]]\n",
    "        inp_shift = [prev]\n",
    "        if self.backward: inp_shift.insert(0,inp[:,1:] if self.batch_first else inp[1:])\n",
    "        else:             inp_shift.append(inp[:,:-1] if self.batch_first else inp[:-1])\n",
    "        inp_shift = torch.cat(inp_shift, dim)\n",
    "        return torch.cat([inp, inp_shift], 2)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qrnn_fwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True)\n",
    "qrnn_bwd = QRNNLayer(10, 20, save_prev_x=True, zoneout=0, window=2, output_gate=True, backward=True)\n",
    "qrnn_bwd.load_state_dict(qrnn_fwd.state_dict())\n",
    "x_fwd = torch.randn(7,5,10)\n",
    "x_bwd = x_fwd.clone().flip(1)\n",
    "y_fwd,h_fwd = qrnn_fwd(x_fwd)\n",
    "y_bwd,h_bwd = qrnn_bwd(x_bwd)\n",
    "test_close(y_fwd, y_bwd.flip(1), eps=1e-4)\n",
    "test_close(h_fwd, h_bwd, eps=1e-4)\n",
    "y_fwd,h_fwd = qrnn_fwd(x_fwd, h_fwd)\n",
    "y_bwd,h_bwd = qrnn_bwd(x_bwd, h_bwd)\n",
    "test_close(y_fwd, y_bwd.flip(1), eps=1e-4)\n",
    "test_close(h_fwd, h_bwd, eps=1e-4)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "#export\n",
    "class QRNN(Module):\n",
    "    \"Apply a multiple layer Quasi-Recurrent Neural Network (QRNN) to an input sequence.\"\n",
    "    def __init__(self, input_size, hidden_size, n_layers=1, batch_first=True, dropout=0,\n",
    "                 bidirectional=False, save_prev_x=False, zoneout=0, window=None, output_gate=True):\n",
    "        assert not (save_prev_x and bidirectional), \"Can't save the previous X with bidirectional.\"\n",
    "        kwargs = dict(batch_first=batch_first, zoneout=zoneout, output_gate=output_gate)\n",
    "        self.layers = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size, save_prev_x=save_prev_x,\n",
    "                                               window=((2 if l ==0 else 1) if window is None else window), **kwargs)\n",
    "                                     for l in range(n_layers)])\n",
    "        if bidirectional:\n",
    "            self.layers_bwd = nn.ModuleList([QRNNLayer(input_size if l == 0 else hidden_size, hidden_size,\n",
    "                                                       backward=True, window=((2 if l ==0 else 1) if window is None else window),\n",
    "                                                       **kwargs) for l in range(n_layers)])\n",
    "        self.n_layers,self.batch_first,self.dropout,self.bidirectional = n_layers,batch_first,dropout,bidirectional\n",
    "\n",
    "    def reset(self):\n",
    "        \"Reset the hidden state.\"\n",
    "        for layer in self.layers: layer.reset()\n",
    "        if self.bidirectional:\n",
    "            for layer in self.layers_bwd: layer.reset()\n",
    "\n",
    "    def forward(self, inp, hid=None):\n",
    "        new_hid = []\n",
    "        if self.bidirectional: inp_bwd = inp.clone()\n",
    "        for i, layer in enumerate(self.layers):\n",
    "            inp, h = layer(inp, None if hid is None else hid[2*i if self.bidirectional else i])\n",
    "            new_hid.append(h)\n",
    "            if self.bidirectional:\n",
    "                inp_bwd, h_bwd = self.layers_bwd[i](inp_bwd, None if hid is None else hid[2*i+1])\n",
    "                new_hid.append(h_bwd)\n",
    "            if self.dropout != 0 and i < len(self.layers) - 1:\n",
    "                for o in ([inp, inp_bwd] if self.bidirectional else [inp]):\n",
    "                    o = F.dropout(o, p=self.dropout, training=self.training, inplace=False)\n",
    "        if self.bidirectional: inp = torch.cat([inp, inp_bwd], dim=2)\n",
    "        return inp, torch.stack(new_hid, 0)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "qrnn = QRNN(10, 20, 2, bidirectional=True, batch_first=True, window=2, output_gate=False)\n",
    "x = torch.randn(7,5,10)\n",
    "y,h = qrnn(x)\n",
    "test_eq(y.size(), [7, 5, 40])\n",
    "test_eq(h.size(), [4, 7, 20])\n",
    "#Without an out gate, the last timestamp in the forward output is the second to last hidden\n",
    "#and the first timestamp of the backward output is the last hidden\n",
    "test_close(y[:,-1,:20], h[2])\n",
    "test_close(y[:,0,20:], h[3])"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Export -"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Converted 00_torch_core.ipynb.\n",
      "Converted 01_layers.ipynb.\n",
      "Converted 02_data.load.ipynb.\n",
      "Converted 03_data.core.ipynb.\n",
      "Converted 04_data.external.ipynb.\n",
      "Converted 05_data.transforms.ipynb.\n",
      "Converted 06_data.block.ipynb.\n",
      "Converted 07_vision.core.ipynb.\n",
      "Converted 08_vision.data.ipynb.\n",
      "Converted 09_vision.augment.ipynb.\n",
      "Converted 09b_vision.utils.ipynb.\n",
      "Converted 09c_vision.widgets.ipynb.\n",
      "Converted 10_tutorial.pets.ipynb.\n",
      "Converted 11_vision.models.xresnet.ipynb.\n",
      "Converted 12_optimizer.ipynb.\n",
      "Converted 13_learner.ipynb.\n",
      "Converted 13a_metrics.ipynb.\n",
      "Converted 14_callback.schedule.ipynb.\n",
      "Converted 14a_callback.data.ipynb.\n",
      "Converted 15_callback.hook.ipynb.\n",
      "Converted 15a_vision.models.unet.ipynb.\n",
      "Converted 16_callback.progress.ipynb.\n",
      "Converted 17_callback.tracker.ipynb.\n",
      "Converted 18_callback.fp16.ipynb.\n",
      "Converted 19_callback.mixup.ipynb.\n",
      "Converted 20_interpret.ipynb.\n",
      "Converted 20a_distributed.ipynb.\n",
      "Converted 21_vision.learner.ipynb.\n",
      "Converted 22_tutorial.imagenette.ipynb.\n",
      "Converted 23_tutorial.transfer_learning.ipynb.\n",
      "Converted 24_vision.gan.ipynb.\n",
      "Converted 30_text.core.ipynb.\n",
      "Converted 31_text.data.ipynb.\n",
      "Converted 32_text.models.awdlstm.ipynb.\n",
      "Converted 33_text.models.core.ipynb.\n",
      "Converted 34_callback.rnn.ipynb.\n",
      "Converted 35_tutorial.wikitext.ipynb.\n",
      "Converted 36_text.models.qrnn.ipynb.\n",
      "Converted 37_text.learner.ipynb.\n",
      "Converted 38_tutorial.ulmfit.ipynb.\n",
      "Converted 40_tabular.core.ipynb.\n",
      "Converted 41_tabular.data.ipynb.\n",
      "Converted 42_tabular.model.ipynb.\n",
      "Converted 43_tabular.learner.ipynb.\n",
      "Converted 45_collab.ipynb.\n",
      "Converted 50_datablock_examples.ipynb.\n",
      "Converted 60_medical.imaging.ipynb.\n",
      "Converted 65_medical.text.ipynb.\n",
      "Converted 70_callback.wandb.ipynb.\n",
      "Converted 71_callback.tensorboard.ipynb.\n",
      "Converted 97_test_utils.ipynb.\n",
      "Converted index.ipynb.\n"
     ]
    }
   ],
   "source": [
    "#hide\n",
    "from nbdev.export import notebook2script\n",
    "notebook2script()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "jupytext": {
   "split_at_heading": true
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}
